import gradio as gr
import cv2
import numpy as np
from PIL import Image
import pickle

# 加载模型
with open('best_model.pkl', 'rb') as file:  # 替换为您的模型文件路径
    model = pickle.load(file)

# 预处理图像
def preprocess_image(image):
    # 将 PIL 图像对象转换为 NumPy 数组
    img = np.array(image)
    # 将 RGB 转换为灰度图像
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # 调整图片大小为 32x32 像素
    img = cv2.resize(img, (32, 32))
    # 展平图像为 1 维数组
    img = img.flatten()
    return img

# 预测图像
def predict_image(image):
    processed_img = preprocess_image(image)
    # 根据模型实际使用的方法进行调整
    prediction = model.predict([processed_img])  # 注意这里使用列表包裹数组
    # 假设 0 代表猫，1 代表狗
    return '猫' if prediction[0] == 0 else '狗'

# 创建 Gradio 界面
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.components.Image(type="pil"),  # 定义输入类型为图像
    outputs='text',
    title="猫狗图像分类器",
    description="上传一张图像来预测它是猫还是狗。"
)

# 启动界面
iface.launch()